Задание 2: Трехклассовая семантическая сегментация¶

Предлагается решить задачу семантической сегментации животных с тремя классами: класс "фон" (метка 0), класс "кошка" (метка 1) и класс "собака" (метка 2). Image

Для этого сами подготовим датасет, реализуем метрики/функции потерь, реализуем и обучим свою PSPNet-подобную архитектуру.

Загрузка модулей¶

In [1]:
# Загружаем pytorch для работы с нейронными сетями
import torch
import torch.nn as nn
import torch.nn.functional as F

# Для работы с изображениями/графиками
from torchvision import transforms
# Загружаем способы интерполяции изображений
from torchvision.transforms.functional import InterpolationMode as IM
import matplotlib.pyplot as plt

# Для логирования метрик и функций потерь в ходе обучения
from torch.utils.tensorboard import SummaryWriter

# Для удобной работы с обучающей/тестовой выборкой
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Прочее
import numpy as np
from tqdm.auto import tqdm
/home/r.fazylov/anaconda3/envs/pl_vl_170/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Часть 1: Подготовка данных¶

1.1 Предобработка датасета (1 балл)¶

Для начала работы с данными требуется выполнить следующие пункты:

  • Определиться со способом хранения/чтения данных с диска. В задачах комьютерного зрения датасеты, как правило, имеют большой размер, который не помещается в оперативную память. Поэтому предлагается несколько форматов хранения: HDF5, memory-mapped files и "сырой" вид, т.е хранение .jpg/.png файлов на диске. Классы с указанными способами уже описаны в файле utils.py. Вам предлагается лишь замерить скорость чтения данных для каждого из форматов, затем выбрать наиболее быстрый (чуть позже).

    • Поговорим поподробнее об особенностях этих форматов хранения. Формат hdf5 позволяет разбивать массивы информации на chunks, которые организованы в виде B-деревьев. Это имеет смысл при чтении hyperslabs - многомерных срезов массива, которые несмежны в памяти (non-contiguous). По умолчанию, hdf5 хранит данные непрерывно (contiguous)
    • Memory-mapping файлов в оперативную память позволяет пропустить этап буфферизации, тем самым пропуская операцию копирования, лениво загружая информацию напрямую. Особенность этого подхода в том, что алгоритмически Best case скорости чтения достигается на непрерывном блоке информации (contiguous), а Worst case - наоборот, на несмежном в памяти (non-contiguous) блоке (на порядки хуже, чем потенциально возможно в hdf5).
  • Привести все пары (изображение, маска) к единому размеру target_shape, указанному далее в словаре конфигурации default_config. Предлагается следующая последовательность действий:

    1. При помощи transforms.Resize мы можем интерполировать (по умолчанию билинейная интерполяция) значения пикселей при изменении размера исходного изображения до заданного. Однако, подобная операция искажает исходное соотношение сторон изображения, что может негативно сказаться на предсказательной способности сети. Например, общий вид морды кошки будет зависеть от исходного размера изображения, а не от сущности класса "кошка": оно может быть не растянуто, может быть растянуто вертикально/горизонтально. Неконсистентность в представлении одной и той же сущности может привести к нестабильному обучению, так как размеры ядра свертки едины для любого входного изображения! К счастью, эта проблема уже решена в transforms.Resize: при целочисленном аргументе size наименьшая сторона входного изображения будет интерполирована до size, а другая сторона (наибольшая) до размера size * aspect_ratio, т.е сохраняя соотношение сторон aspect_ratio
    2. На текущий момент лишь одна из сторон исходного изображения соответствует требуемому размеру target_shape. Возможны два случая: оставшаяся сторона меньше или больше (случай с равенством можно свести к ситуации "меньше на 0") требуемого размера. В первом случае будем дополнять изображение пикселями со значением pad_value при помощи transforms.Pad, а во втором - обрезать изображение при помощи transforms.CenterCrop.

Последовательное исполнение операций модуля transforms можно выполнить при помощи transforms.Compose.

  • Ответить на вопрос: А зачем, вообще, требуется сводить все изображения к одному размеру?

Ваш ответ: чтобы стакать их в батчи и хранить батчи как тензор [N, C, W, H]

In [2]:
from utils import *

def resize(img: Image, target_shape: tuple[int, int], pad_val: int) -> np.array:
    """
    Приводит входное изображение `img` к размеру `target_shape`, указанной выше 
    последовательностью действий. Предполагается, что требуемый размер `target_shape` "квадратный"
    """
    # Проверяем равенство желаемых размеров сторон изображения
    assert target_shape[0] == target_shape[1]
    
    # Необходимо для "универсальности" определения количества недостающих padding-пикселей.
    # В случае если размерность текущего изображения меньше требуемой - имеем неотрицательное
    # Количество недостающих padding-пикселей, relu возвращает это число без изменений.
    # В случае если размерность текущего изображения больше требуемой - имеем отрицательное
    # Количество недостающих padding-пикселей, т.е в padding пикселях мы не нуждаемся и 
    # relu возвращает значение 0.
    def relu(x):
        return x * (x > 0)
    
    # Масштабируем наименьшую размерность `img` под `target_shape`
    # В качестве способа интерполяции выберем интерполяцию методом ближайшего соседа
    # Это необходимо для сохранения множества значений маски сегментации
    img = transforms.Resize(target_shape[0], interpolation=IM.NEAREST)(img)
    
    # Вычисляем количество недостающих padding-пикселей для каждой из сторон изображения
    h, w = img.size[0], img.size[1]
    h_diff, w_diff = target_shape[0] - h, target_shape[1] - w

    pad_left, pad_right = w_diff // 2, w_diff - w_diff // 2
    pad_top, pad_bottom = h_diff // 2, h_diff - h_diff // 2
    
    resize_transform = transforms.Compose([
        # Добавляем padding-пиксели. Если их нет, то операция Pad ничего не изменит (случай "больше").
        transforms.Pad(
            (relu(pad_left), relu(pad_top), relu(pad_right), relu(pad_bottom)),
            fill=pad_val, padding_mode="constant"
        ),
        # Обрезаем "лишние" пиксели. Если их нет, то CenterCrop ничего не изменит (случай "меньше").
        transforms.CenterCrop(target_shape),
        
        # Преобразуем PIL.Image изображение в массив np.array
        transforms.Lambda(lambda x: np.array(x))
    ])

    return resize_transform(img)
In [3]:
def prepare_dataset(config: dict, storage_class: Type[storage_class]):
    """
    Предобрабатывает датасет и эффективно его сохраняет на диск
    """
    with open(config["annotation_file"]) as f:
        lines = f.readlines()
        
    # Заводим массивы для блоков изображений, помещаемых в память
    input_chunk = np.empty((config["chunk_size"], *config["target_shape"], 3), dtype=np.uint8)
    target_chunk = np.empty((config["chunk_size"], *config["target_shape"]), dtype=np.uint8)
    
    # Делим датасет на блоки
    config["dataset_size"] = len(lines)
    num_chunks = config["dataset_size"] // config["chunk_size"] + bool(config["dataset_size"] % config["chunk_size"])
    dataset = storage_class(config)
    
    # Читаем изображения с диска, предобрабатываем и сохраняем в выбранный нами формат
    for chunk_idx in tqdm(range(num_chunks)):
        for pos in range(config["chunk_size"]):
            flat_idx = chunk_idx * config["chunk_size"] + pos
            if (flat_idx >= config["dataset_size"]):
                break
                                 
            img_name, label = lines[flat_idx].rstrip("\n").split(' ')
                              
            input_raw = Image.open(os.path.join(config["input_dir"], img_name + ".jpg")).convert("RGB")
            target_raw = Image.open(os.path.join(config["target_dir"], img_name + ".png")).convert('L')

            input_chunk[pos] = resize(input_raw, config["target_shape"], 0)
            target_chunk[pos] = renumerate_target(resize(target_raw, config["target_shape"], 2), int(label))
        dataset.append(input_chunk, target_chunk)
    dataset.lock()
    
    return dataset

Для простоты будем выбирать размер изображений target_shape с одинаковыми сторонами. Предлагается использовать размер 256x256, хотя выбор за вами. Обратите внимание, что от размера изображений зависит быстродействие дальнейшего кода (чем больше картинки, тем дольше обучать).

In [5]:
# Конфигурация датасета
default_config = {
             "input_dir": "SegTask/images",
             "target_dir": "SegTask/seg_masks",
             "target_shape": (256, 256), # Можно любой другой размер картинки
             "chunk_size": 512, # количество изображений в блоке, загружаемых в оперативную память
            }

# Конфигурации обучающей и тестовой выборок отличаются файлов аннотации
config_train = {"annotation_file": "SegTask/trainval.txt"} | default_config
config_test = {"annotation_file": "SegTask/test.txt"} | default_config

train_data_hdf5 = prepare_dataset(config_train, storage_hdf5)
train_data_memmap = prepare_dataset(config_train, storage_memmap)
train_data_raw = prepare_dataset(config_train, storage_raw)
100%|██████████| 8/8 [00:36<00:00,  4.52s/it]
100%|██████████| 8/8 [00:21<00:00,  2.68s/it]
100%|██████████| 8/8 [00:51<00:00,  6.50s/it]

1.2 Создание Dataset и DataLoader (1.5 балла)¶

Pytorch предоставляет нам удобные обертки Dataset и DataLoader для наших данных, которые эффективно нарезают наш датасет на batches (блоки) заданного размера, а также параллелизуют процесс чтения на num_workers нитей.

Также для дальнейшей работы нам понадобится аугментация данных. Ее цель заключается в еще большем расширении обучающей выборки путем применения преобразований над изображениями, которые изменяют их абсолютные значения пикселей, но не нарушают их информационное наполнение.

Например, преобразование ColorJitter способно изменить яркость изображения на случайное число, что не изменяет его контекст. Однако, преобразование RandomCrop не рекомендуется, посколько есть шанс, что мордочка животного не попадет в фото и класс животного будет неоднозначен. Таким образом, при каждом вызове объекта из обучающей выборки к нему будет применяться случайное преобразование/серия случайных преобразований. Обратите внимание, что преобразование изображения должно быть согласованным с его сегментационной маской.

Требуется реализовать предлагаемые ниже преобразования аугментации:

  • HorizontalFlip (0.25 балла)
  • ColorJitter (0.25 балла)
  • RandomPerspective (0.5 балла)

Для каждого из указанных преобразований требуется написать магический метод __call__, который позволяет обращаться к объекту класса (преобразованию), как к функции (функтор из C++):

# инициализация
obj = Example()
# вызывается __call__
obj()
In [5]:
# Предлагается использовать эти функции
# Самому писать процедуры отражения картинки по вертикали/горизонтали или цветокоррекции не надо!
from torchvision.transforms.functional import hflip
from torchvision.transforms.functional import perspective
from torchvision.transforms import ColorJitter as CJ
from torchvision.transforms import RandomPerspective as RP
import random

class HorizontalFlip():
    def __init__(self, prob: float):
        self.p = prob

    def __call__(self, pair: tuple[Image, Image]) -> tuple[Image, Image]:
        """
        `pair` содержит пару (изображение, сегментационная маска)
        """
        """
        ==== YOUR CODE =====
             ¯\_(ツ)_/¯
        """
        if random.random() > self.p:
            return pair
        return hflip(pair[0]), hflip(pair[1])

    
class ColorJitter():
    def __init__(self, prob: float, param: tuple[float]):
        self.p = prob
        """
        ==== YOUR CODE =====
             ¯\_(ツ)_/¯
        """
        self.jitter = CJ(brightness=param[0], contrast=param[1], saturation=param[2])
    
    def __call__(self, pair: tuple[Image, Image]) -> tuple[Image, Image]:
        """
        `pair` содержит пару (изображение, сегментационная маска)
        """
        """
        ==== YOUR CODE =====
             ¯\_(ツ)_/¯
        """
        if random.random() > self.p:
            return pair
        return self.jitter(pair[0]), pair[1]


class RandomPerspective():
    def __init__(self, prob: float, param: float):
        self.p = prob
        """
        ==== YOUR CODE =====
             ¯\_(ツ)_/¯
        """
        self.distortion_scale = param
    def __call__(self, pair: tuple[Image, Image]) -> tuple[Image, Image]:
        """
        `pair` содержит пару (изображение, сегментационная маска)
        """
        """
        ==== YOUR CODE =====
             ¯\_(ツ)_/¯
        """
        if random.random() > self.p:
            return pair
        
        return [perspective(img, *RP.get_params(*pair[0].size, self.distortion_scale), interpolation=IM.NEAREST) for img in pair]

Применим реализованные преобразования и убедимся в их работоспособности:

In [7]:
img_idx = np.random.randint(0, 100)
f, ax = plt.subplots(2, 4, figsize=(16, 8))
pair = train_data_hdf5[img_idx]

imgs2draw = {"Source": pair,
            "HorizontalFlip": HorizontalFlip(1.0)(pair),
            "ColorJitter": ColorJitter(1.0, (0.4, 0.4, 0.4))(pair),
            "RandomPerspective": RandomPerspective(1.0, 0.25)(pair)
}
for idx, (name, pair) in enumerate(imgs2draw.items()):
    ax[0, idx].imshow(pair[0])
    ax[0, idx].set_title(name, fontsize=20)
    ax[1, idx].imshow(colorize(np.array(pair[1])))

plt.show()
/home/r.fazylov/anaconda3/envs/pl_vl_170/lib/python3.9/site-packages/torchvision/transforms/functional.py:594: UserWarning: torch.lstsq is deprecated in favor of torch.linalg.lstsq and will be removed in a future PyTorch release.
torch.linalg.lstsq has reversed arguments and does not return the QR decomposition in the returned tuple (although it returns other information about the problem).
To get the qr decomposition consider using torch.linalg.qr.
The returned solution in torch.lstsq stored the residuals of the solution in the last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, the residuals in the field 'residuals' of the returned named tuple.
The unpacking of the solution, as in
X, _ = torch.lstsq(B, A).solution[:A.size(1)]
should be replaced with
X = torch.linalg.lstsq(A, B).solution (Triggered internally at  /pytorch/aten/src/ATen/LegacyTHFunctionsCPU.cpp:389.)
  res = torch.lstsq(b_matrix, a_matrix)[0]

Далее описываем наш класс SegmentationData и операции приведения изображений типа PIL.Image к pytorch тензорам с ImageNet нормализацией. ImageNet нормализация - это частный случай Standard normalization, в котором поканальное среднее (цветовые каналы red, green, blue) и поканальное среднеквадратическое отклонение вычислены на огромной выборке изображений.

Ответьте на вопрос: А для чего нужно применять нормализацию к изображениям?

Ваш ответ: сверточный слой, так же как и линейный, чувствителен к масштабу входных значений. Обучившись на одном распределении значений пикселей (масштабе), сеть успешно будет работать только для изображений из этого же масштаба. Поэтому мы и нормализуем входные изображения. Внутри сети за нормализацию отвечают уже нормализационные слои (BatchNorm, LayerNorm, etc.).
Кроме того, это еще и вычислительно стабильнее, благодаря распределению значений от -1 до 1 (mean=0, std=1)

In [53]:
# Определяем устройство для вычислений (!желательно GPU!)
DEVICE = 'cuda:1' if torch.cuda.is_available() else 'cpu'

t_dict = {
    "forward_input": transforms.Compose([
        transforms.PILToTensor(),
        transforms.Lambda(lambda x: x.float().to(DEVICE)/255.0),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    ]),
    "backward_input": transforms.Compose([
        transforms.Normalize(mean=[0.0, 0.0, 0.0],
                                     std=[1./0.229, 1./0.224, 1./0.225]),
        transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                     std=[1.0, 1.0, 1.0]),
        transforms.Lambda(lambda x: x.permute(1, 2, 0).cpu().numpy())
    ]),
    "forward_target": transforms.Compose([
        transforms.PILToTensor(),
        transforms.Lambda(lambda x: x.long().squeeze().to(DEVICE)),
    ]),
    "backward_target": transforms.Compose([
        transforms.Lambda(lambda x: x.cpu().numpy())
    ]),
    "augment": transforms.Compose([
        HorizontalFlip(0.5),
        ColorJitter(0.5, (0.4, 0.4, 0.4)),
        RandomPerspective(0.5, 0.25)
    ]),
}


class SegmentationDataset(Dataset):
    def __init__(self, dataset_raw: Type[storage_class], transforms: dict, train_flag: bool = True):
        """
        Наследуем весь функционал из `Dataset` для наших данных `dataset_raw`
        `transforms` содержит преобразования PIL.Image <-> torch.tensor и аугментации
        `train_flag` регулирует аугментацию данных (для тестовой выборки она не нужна)
        """
        super().__init__()
        self.dataset_raw = dataset_raw
        self.transforms = transforms
        self.train_flag = train_flag

    def __len__(self):
        return self.dataset_raw.dataset_size
    
    def __getitem__(self, idx: int) -> tuple[Image, Image]:
        input, target = self.dataset_raw[idx]
        
        if (self.train_flag):
            input, target = self.transforms["augment"]((input, target))
            
        return self.transforms["forward_input"](input), self.transforms["forward_target"](target)
In [54]:
from torch.utils.data import random_split

# Разделяем обучающую выборку на обучающую и валидационную
def split_train_val(train_data: Type[storage_class], train_portion: float = 0.8):
    """
    `train_data` предобработанные данные
    `train_portion` доля объектов, которая будет приходиться на обучающую выборку
    """
    trainval_dataset = SegmentationDataset(train_data, t_dict, train_flag=True)
    
    train_size = int(len(trainval_dataset) * train_portion)
    val_size = len(trainval_dataset) - train_size
    return random_split(trainval_dataset, [train_size, val_size])

# train_dataset_hdf5, val_dataset_hdf5 = split_train_val(train_data_hdf5)
train_dataset_memmap, val_dataset_memmap = split_train_val(train_data_memmap)
# train_dataset_raw, val_dataset_raw = split_train_val(train_data_raw)
In [ ]:
from torch.utils.data import random_split

# Разделяем обучающую выборку на обучающую и валидационную
def split_train_val(train_data: Type[storage_class], train_portion: float = 0.8):
    """
    `train_data` предобработанные данные
    `train_portion` доля объектов, которая будет приходиться на обучающую выборку
    """
    trainval_dataset = SegmentationDataset(train_data, t_dict, train_flag=True)
    
    train_size = int(len(trainval_dataset) * train_portion)
    val_size = len(trainval_dataset) - train_size
    return random_split(trainval_dataset, [train_size, val_size])

train_dataset_hdf5, val_dataset_hdf5 = split_train_val(train_data_hdf5)
train_dataset_memmap, val_dataset_memmap = split_train_val(train_data_memmap)
train_dataset_raw, val_dataset_raw = split_train_val(train_data_raw)

Отрисуем случайное изображение (после применения случайных преобразований аугментации):

In [10]:
img_idx = np.random.randint(0, 100)
draw(train_dataset_hdf5[img_idx], t_dict);
/home/r.fazylov/anaconda3/envs/pl_vl_170/lib/python3.9/site-packages/torchvision/transforms/functional.py:165: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)
  img = torch.as_tensor(np.asarray(pic))
In [ ]:
dataloader_config = {
    "batch_size": 16, # Ваше значение
    "shuffle": True,
    "num_workers": 0
}
train_dataloader_hdf5 = DataLoader(train_dataset_hdf5, **dataloader_config)
val_dataloader_hdf5 = DataLoader(val_dataset_hdf5, **dataloader_config)

train_dataloader_memmap = DataLoader(train_dataset_memmap, **dataloader_config)
val_dataloader_memmap = DataLoader(val_dataset_memmap, **dataloader_config)

train_dataloader_raw = DataLoader(train_dataset_raw, **dataloader_config)
val_dataloader_raw = DataLoader(val_dataset_raw, **dataloader_config)

1.3 Замер скорости чтения датасета с диска (0.5 балла)¶

Замерьте время чтения нашего датасета для каждого из форматов хранения:

In [12]:
def speedtest(dataloader: Type[DataLoader]) -> None:
    for batch in dataloader:
        pass
In [14]:
%timeit speedtest(train_dataloader_hdf5)
8.23 s ± 156 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [13]:
%timeit speedtest(train_dataloader_memmap)
9.48 s ± 164 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [15]:
%timeit speedtest(train_dataloader_raw)
10.8 s ± 208 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Ответьте на вопрос: Какой формат оказался самым эффективным по скорости? Почему?

Ваш ответ: самым эффективным оказался hdf5 благодаря тому что мы рандомно обращаемся к данным, то есть непоследовательно (shuffle=True). А он как раз и предназначен для этого

Создайте тестовый Dataloader победившего по скорости формата.

In [9]:
"""
==== YOUR CODE =====
     ¯\_(ツ)_/¯
"""
test_data_memmap = prepare_dataset(config_test, storage_memmap)
test_ds = SegmentationDataset(test_data_memmap, t_dict, train_flag=False)
test_dataloader = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=0)
100%|██████████| 8/8 [02:44<00:00, 20.61s/it]

Часть 2: Реализация функций потерь, метрик и декодировщика PSPNet¶

Ранее вас познакомили с архитектурой Unet - сверточными автокодировщиком, применяемом в области сегментации изображений. В данном задании мы разберем более продвинутую архитектуру сети сегментации PSPNet. Отличительной особенностью этой сети является Pyramid Pooling Module, который, в отличие от Unet, позволяет учитывать глобальный контекст изображения при формировании признаков его локальных областей.

Рассмотрим предлагаемую архитектуру PSPNet-подобной сети: Image

В качестве кодировщика Encoder будем брать предобученную ResNeXt сеть. Будем его использовать для получения двух глубинных представлений нашего входого изображения x:

  • выход x_main - "среднее" промежуточное представление, компромисс между низкоуровневыми признаками (цвет, контуры объектов, штрихи) и высокоуровневыми признаками (абстрактные признаки, отражающие семантику изображения)
  • выход x_supp - финальное представление, содержащее самые высокоуровневые признаки, в которых значительно утеряна информация о точном простанственном расположении объектов

Подобное разбиение выхода на 2 потока объясняется необходимостью в закодированной информации о пространственном расположении объектов (x_main) и вспомогательной информации о семантике всего изображения в целом (x_supp) для задачи семантической сегментации. Мы не можем себе позволить использовать лишь выход x_supp, как это делается, например, в задачах классификации, ведь от нас требуется дополнительное знание о расположении этого объекта на изображении.

Ваша задача состоит в написании декодировщика Decoder, а именно в написании блоков:

  • Pyramid Pooling Module. К входному тензору x_main параллельно применяется несколько операций пулинга разных размеров, которые сводят пространственные размерности исходного тензора до размеров 1x1, 2x2, 3x3 и 6x6. Каналы промежуточных тензоров эффективно редуцируются (при помощи nn.Conv2d c размером фильтра 1x1), а затем пространственные размерности интерполируются до исходных размеров. Эта процедура необходима для извлечения глобального контекста разных масштабов, которого не хватает классическим сверточным нейронным сетям (локальный контекст в пределах размера фильтра). Таким образом, выходной тензор, полученный конкатенацией этих глобальных контекстов, содержит информацию о всем входном тензоре с разными уровнями детализации. Промежуточное редуцирование каналов тензоров производится для сжатия информации, а также для индивидуального взвешивания глобального контекста каждого масштаба. Требуется реализовать forward этап этого блока. Для уточнения информации можно обратиться к статье.
  • Supplementary Module осуществляет нелинейное преобразование над входным тензором x_supp с понижением числа каналов до размерности выхода модуля Pyramid Pooling Module. Вариант архитектуры этого преобразования (композиции слоев) уже предложен, но, при желании, вы можете с ним экспериментировать
  • Upsample Module осуществляет нелинейные преобразования над входным тензором с понижением числа каналов, которые чередуются с интерполяцией пространственных размерностей в 2 раза (увеличение). Таким образом, выход этого блока имеет ту же пространственную размерность, что и входное в кодировщик изображение. Это преобразование (слои Layer 0, Layer 1 и Layer 2) требуется экспериментально подобрать
  • Segmentation Head нелинейно преобразует входной тензор в тензор score'ов. Имеем, что выходной тензор для каждого пикселя имеет num_classes score'ов (в нашем случае 3). В дальнейшем, индекс максимального score'а для заданного пикселя и будет его меткой класса (0, 1 или 2). Это преобразование (композиция слоев) требуется экспериментально подобрать

Если декодировщик получается слишком тяжелый, то оператор Concat можно заменить на поканальное суммирование. Обратите внимание, что нет единственной правильной архитектуры указанных выше блоков. Требуется ее экспериментально подобрать так, чтобы получить наилучшее качество сегментации за разумную сложность.

2.1 Кодировщик и декодировщик PSPNet-подобной сети (2.5 балла)¶

In [10]:
from torchvision.models.resnet import ResNet
from torchvision.models import resnext50_32x4d

pretrained_model = resnext50_32x4d(pretrained=True)

# Выставляем evaluation mode (влияет на поведение таких слоев как BatchNorm2d, Dropout)
pretrained_model.eval()
Out[10]:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer2): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer3): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (4): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (5): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

Так как кодировщик используется предобученный, то требуется зафиксировать (заморозить) веса, чтобы по ним не тек градиент. Этим мы гарантируем, что кодировщик не изменяется в ходе обучения автокодировщика, а также экономим вычислительные ресурсы (граф градиента кодировщика не строится).

In [11]:
class EncoderBlock(nn.Module):
    def __init__(self, pretrained_model: Type[ResNet]):
        """
        Извлекает предобученные именованные слои кодировщика `pretrained_model`
        Разделяет слои на `main` и `supp` потоки (см. архитектуру выше)
        
        Вход: тензор (Batch_size, 3, Height, Width)
        
        Выход: x_main тензор (Batch_size, 512, Height // 8, Width // 8)
        Выход: x_supp тензор (Batch_size, 2048, Height // 32, Width // 32)
        """
        super().__init__()

        self.encoder_main = nn.Sequential()
        for name, child in list(pretrained_model.named_children())[:-4]:
            print(f"Pretrained main module {name} is loaded")
            self.encoder_main.add_module(name, child)
            
        self.encoder_supp = nn.Sequential()
        for name, child in list(pretrained_model.named_children())[-4:-2]:
            print(f"Pretrained supp module {name} is loaded")
            self.encoder_supp.add_module(name, child)
            
    def freeze(self) -> None:
        """
        Замораживает веса кодировщика
        """
        for p in self.parameters():
            p.requires_grad = False
        self.eval()
            
    def unfreeze(self) -> None:
        """
        Размораживает веса кодировщика
        """
        for p in self.parameters():
            p.requires_grad = True
        self.train()
            
    def forward(self, x: torch.tensor) -> tuple[torch.tensor, torch.tensor]:
        x_main = self.encoder_main(x)
        x_supp = self.encoder_supp(x_main)
        return x_main, x_supp
In [12]:
encoder = EncoderBlock(pretrained_model)
Pretrained main module conv1 is loaded
Pretrained main module bn1 is loaded
Pretrained main module relu is loaded
Pretrained main module maxpool is loaded
Pretrained main module layer1 is loaded
Pretrained main module layer2 is loaded
Pretrained supp module layer3 is loaded
Pretrained supp module layer4 is loaded

Для оценки сложности модели нам понадобится функция подсчета числа ее параметров, для этого используйте метод .parameters(). Реализуйте ее ниже:

Также полезно было бы убедиться, что метод .parameters() возвращает то, что мы от него ожидаем. Для этого воспользуемся методом .named_parameters():

In [20]:
for name, parameter in encoder.named_parameters():
    print(name)
encoder_main.conv1.weight
encoder_main.bn1.weight
encoder_main.bn1.bias
encoder_main.layer1.0.conv1.weight
encoder_main.layer1.0.bn1.weight
encoder_main.layer1.0.bn1.bias
encoder_main.layer1.0.conv2.weight
encoder_main.layer1.0.bn2.weight
encoder_main.layer1.0.bn2.bias
encoder_main.layer1.0.conv3.weight
encoder_main.layer1.0.bn3.weight
encoder_main.layer1.0.bn3.bias
encoder_main.layer1.0.downsample.0.weight
encoder_main.layer1.0.downsample.1.weight
encoder_main.layer1.0.downsample.1.bias
encoder_main.layer1.1.conv1.weight
encoder_main.layer1.1.bn1.weight
encoder_main.layer1.1.bn1.bias
encoder_main.layer1.1.conv2.weight
encoder_main.layer1.1.bn2.weight
encoder_main.layer1.1.bn2.bias
encoder_main.layer1.1.conv3.weight
encoder_main.layer1.1.bn3.weight
encoder_main.layer1.1.bn3.bias
encoder_main.layer1.2.conv1.weight
encoder_main.layer1.2.bn1.weight
encoder_main.layer1.2.bn1.bias
encoder_main.layer1.2.conv2.weight
encoder_main.layer1.2.bn2.weight
encoder_main.layer1.2.bn2.bias
encoder_main.layer1.2.conv3.weight
encoder_main.layer1.2.bn3.weight
encoder_main.layer1.2.bn3.bias
encoder_main.layer2.0.conv1.weight
encoder_main.layer2.0.bn1.weight
encoder_main.layer2.0.bn1.bias
encoder_main.layer2.0.conv2.weight
encoder_main.layer2.0.bn2.weight
encoder_main.layer2.0.bn2.bias
encoder_main.layer2.0.conv3.weight
encoder_main.layer2.0.bn3.weight
encoder_main.layer2.0.bn3.bias
encoder_main.layer2.0.downsample.0.weight
encoder_main.layer2.0.downsample.1.weight
encoder_main.layer2.0.downsample.1.bias
encoder_main.layer2.1.conv1.weight
encoder_main.layer2.1.bn1.weight
encoder_main.layer2.1.bn1.bias
encoder_main.layer2.1.conv2.weight
encoder_main.layer2.1.bn2.weight
encoder_main.layer2.1.bn2.bias
encoder_main.layer2.1.conv3.weight
encoder_main.layer2.1.bn3.weight
encoder_main.layer2.1.bn3.bias
encoder_main.layer2.2.conv1.weight
encoder_main.layer2.2.bn1.weight
encoder_main.layer2.2.bn1.bias
encoder_main.layer2.2.conv2.weight
encoder_main.layer2.2.bn2.weight
encoder_main.layer2.2.bn2.bias
encoder_main.layer2.2.conv3.weight
encoder_main.layer2.2.bn3.weight
encoder_main.layer2.2.bn3.bias
encoder_main.layer2.3.conv1.weight
encoder_main.layer2.3.bn1.weight
encoder_main.layer2.3.bn1.bias
encoder_main.layer2.3.conv2.weight
encoder_main.layer2.3.bn2.weight
encoder_main.layer2.3.bn2.bias
encoder_main.layer2.3.conv3.weight
encoder_main.layer2.3.bn3.weight
encoder_main.layer2.3.bn3.bias
encoder_supp.layer3.0.conv1.weight
encoder_supp.layer3.0.bn1.weight
encoder_supp.layer3.0.bn1.bias
encoder_supp.layer3.0.conv2.weight
encoder_supp.layer3.0.bn2.weight
encoder_supp.layer3.0.bn2.bias
encoder_supp.layer3.0.conv3.weight
encoder_supp.layer3.0.bn3.weight
encoder_supp.layer3.0.bn3.bias
encoder_supp.layer3.0.downsample.0.weight
encoder_supp.layer3.0.downsample.1.weight
encoder_supp.layer3.0.downsample.1.bias
encoder_supp.layer3.1.conv1.weight
encoder_supp.layer3.1.bn1.weight
encoder_supp.layer3.1.bn1.bias
encoder_supp.layer3.1.conv2.weight
encoder_supp.layer3.1.bn2.weight
encoder_supp.layer3.1.bn2.bias
encoder_supp.layer3.1.conv3.weight
encoder_supp.layer3.1.bn3.weight
encoder_supp.layer3.1.bn3.bias
encoder_supp.layer3.2.conv1.weight
encoder_supp.layer3.2.bn1.weight
encoder_supp.layer3.2.bn1.bias
encoder_supp.layer3.2.conv2.weight
encoder_supp.layer3.2.bn2.weight
encoder_supp.layer3.2.bn2.bias
encoder_supp.layer3.2.conv3.weight
encoder_supp.layer3.2.bn3.weight
encoder_supp.layer3.2.bn3.bias
encoder_supp.layer3.3.conv1.weight
encoder_supp.layer3.3.bn1.weight
encoder_supp.layer3.3.bn1.bias
encoder_supp.layer3.3.conv2.weight
encoder_supp.layer3.3.bn2.weight
encoder_supp.layer3.3.bn2.bias
encoder_supp.layer3.3.conv3.weight
encoder_supp.layer3.3.bn3.weight
encoder_supp.layer3.3.bn3.bias
encoder_supp.layer3.4.conv1.weight
encoder_supp.layer3.4.bn1.weight
encoder_supp.layer3.4.bn1.bias
encoder_supp.layer3.4.conv2.weight
encoder_supp.layer3.4.bn2.weight
encoder_supp.layer3.4.bn2.bias
encoder_supp.layer3.4.conv3.weight
encoder_supp.layer3.4.bn3.weight
encoder_supp.layer3.4.bn3.bias
encoder_supp.layer3.5.conv1.weight
encoder_supp.layer3.5.bn1.weight
encoder_supp.layer3.5.bn1.bias
encoder_supp.layer3.5.conv2.weight
encoder_supp.layer3.5.bn2.weight
encoder_supp.layer3.5.bn2.bias
encoder_supp.layer3.5.conv3.weight
encoder_supp.layer3.5.bn3.weight
encoder_supp.layer3.5.bn3.bias
encoder_supp.layer4.0.conv1.weight
encoder_supp.layer4.0.bn1.weight
encoder_supp.layer4.0.bn1.bias
encoder_supp.layer4.0.conv2.weight
encoder_supp.layer4.0.bn2.weight
encoder_supp.layer4.0.bn2.bias
encoder_supp.layer4.0.conv3.weight
encoder_supp.layer4.0.bn3.weight
encoder_supp.layer4.0.bn3.bias
encoder_supp.layer4.0.downsample.0.weight
encoder_supp.layer4.0.downsample.1.weight
encoder_supp.layer4.0.downsample.1.bias
encoder_supp.layer4.1.conv1.weight
encoder_supp.layer4.1.bn1.weight
encoder_supp.layer4.1.bn1.bias
encoder_supp.layer4.1.conv2.weight
encoder_supp.layer4.1.bn2.weight
encoder_supp.layer4.1.bn2.bias
encoder_supp.layer4.1.conv3.weight
encoder_supp.layer4.1.bn3.weight
encoder_supp.layer4.1.bn3.bias
encoder_supp.layer4.2.conv1.weight
encoder_supp.layer4.2.bn1.weight
encoder_supp.layer4.2.bn1.bias
encoder_supp.layer4.2.conv2.weight
encoder_supp.layer4.2.bn2.weight
encoder_supp.layer4.2.bn2.bias
encoder_supp.layer4.2.conv3.weight
encoder_supp.layer4.2.bn3.weight
encoder_supp.layer4.2.bn3.bias
In [13]:
def count_parameters(model: Type[nn.Module]) -> int:
    """
    Считает число весов в модели `model`, для которых требуется градиент
    """
    total_params = 0
    """
    ==== YOUR CODE =====
         ¯\_(ツ)_/¯
    """
    for name, param in model.named_parameters():
          if param.requires_grad:
               total_params += param.numel() 
    return total_params

Убедимся, что метод .freeze() успешно замораживает веса:

In [22]:
print("Encoder #parameters before freeze():", count_parameters(encoder))
encoder.freeze()
print("Encoder #parameters after freeze():", count_parameters(encoder))
Encoder #parameters before freeze(): 22979904
Encoder #parameters after freeze(): 0

Реализуйте PyramidPoolingModule, Upsample и SegmentationHead (по 0.5 балла), а также заполните пропущенные значения ??? в UpsampleModule и DecoderBlock (по 0.25 балла). Выбор параметров/архитектуры сети в большей степени зависит от результатов обучения в следующей части задания (так что вы еще вернетесь к этому пункту).

In [14]:
class PyramidPoolingModule(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, bin_sizes: tuple[int, ...]):
        """
        Вход: тензор (Batch_size, `in_channels`, Height, Width)
        `bin_sizes` - пространственные размерности для каждой пулинг операции
        Пример: bin_sizes = (1, 2, 3, 6).
        
        Выход: тензор (Batch_size, `in_channels` + len(`bin_sizes`) * `out_channels`, Height, Width)
        """
        super().__init__()
        self.bins = []
        
        for bin_size in bin_sizes:
            self.bins.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(bin_size),
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ))
            
        self.bins = nn.ModuleList(self.bins)

    def forward(self, x: torch.tensor) -> torch.tensor:
        h, w = x.shape[2:]
        out = []
        """
        Осуществите все пулинг-операции с последующим `Upscale`
        Подсказка: используйте torch.functional.interpolate
        ==== YOUR CODE =====
             ¯\_(ツ)_/¯
        """
        for bin in self.bins:
            out.append(F.interpolate(bin(x), size=(h, w), mode='bilinear', align_corners=True))
        return torch.cat([x] + out, dim=1)

    
class SupplementaryModule(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, dropout: float):
        """
        Вход: тензор (Batch_size, `in_channels`, Height, Width)
        
        Выход: тензор (Batch_size, `out_channels`, Height, Width)
        """
        super().__init__()
        mid_channels = 512 # TODO: in_channels // 2 (1024), 512 / 256
        self.suppl = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout), # TODO: убрать dropout
            nn.Conv2d(mid_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        """
        Указанную выше архитектуру можно менять по своему усмотрению
        """

    def forward(self, x: torch.tensor) -> torch.tensor:
        return self.suppl(x)
    
    
class Upsample(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        """
        Вход: тензор (Batch_size, `in_channels`, Height, Width)
        
        Выход: тензор (Batch_size, `out_channels`, 2 * Height, 2 * Width)
        """
        super().__init__()
        self.us_transform = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), # TODO: 2xConv2d / ConvTransposed2d + 1xConv2d
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x: torch.tensor) -> torch.tensor:
        """
        Подсказка: используйте torch.functional.interpolate для удвоения пространственных размерностей
        ==== YOUR CODE =====
             ¯\_(ツ)_/¯
        """
        w, h = x.shape[2:]
        x = self.us_transform(x)
        x = F.interpolate(x, size=(2*w, 2*h), mode='bilinear', align_corners=True)
        return x
    
    
class UpsampleModule(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        """
        Вход: тензор (Batch_size, `in_channels`, Height, Width)
        
        Выход: тензор (Batch_size, `out_channels`, 8 * Height, 8 * Width)
        """
        super().__init__()
        m1_channels = in_channels // 2 # TODO: так как входных каналов мало (PPM_out_channels + supp_out_channels)
        m2_channels = m1_channels // 2
        self.upsample = nn.Sequential(
            Upsample(in_channels, m1_channels),
            Upsample(m1_channels, m2_channels),
            Upsample(m2_channels, out_channels)
        )

    def forward(self, x: torch.tensor) -> torch.tensor:
        return self.upsample(x)
In [15]:
class DecoderBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, bin_sizes: tuple[int, ...], dropout: float = 0.1):
        """
        Вход  x_main: тензор (Batch_size, `in_channels`, Height, Width)
        Вход  x_supp: тензор (Batch_size, 4 * `in_channels`, Height // 4, Width // 4)
        
        Выход: тензор (Batch_size, `out_channels`, 8 * Height, 8 * Width)
        """
        super().__init__()
        assert in_channels % len(bin_sizes) == 0 # in_channels = 512

        bin_out_channels = 1 # в статье 1
        PPM_out_channels = len(bin_sizes) * bin_out_channels + in_channels # 516
        supp_out_channels = 32 # абстрактная информация о классе, не очень нужна имхо, 64 / 32 хватит
        self.PPM = PyramidPoolingModule(in_channels, bin_out_channels, bin_sizes)
        self.SM = SupplementaryModule(4 * in_channels, supp_out_channels, dropout)
        self.UM = UpsampleModule(PPM_out_channels + supp_out_channels, out_channels) 
        
    def forward(self, x_main: torch.tensor, x_supp: torch.tensor) -> torch.tensor:
        h_supp, w_supp = x_supp.shape[2:]
        x_supp = F.interpolate(input=x_supp, size=(4 * h_supp, 4 * w_supp), mode='bilinear', align_corners=True)
        
        x_supp = self.SM(x_supp)
        x_main = self.PPM(x_main)
        
        out = self.UM(torch.cat([x_main, x_supp], dim=1))
        return out

    
class SegmentationHead(nn.Module):
    def __init__(self, in_channels: int, num_classes: int, dropout: float = 0.0):
        """
        Вычисляет score для каждого из классов
        Вход: тензор (Batch_size, `in_channels`, Height, Width)
        
        Выход: тензор (Batch_size, `num_classes`, Height, Width)
        """
        super().__init__()
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(in_channels // 2, num_classes, kernel_size=1),
        )

    def forward(self, x: torch.tensor, x_supp: torch.tensor) -> torch.tensor:
        """
        На будущее зададим фиктивный аргумент `x_supp`, который пока не будем использовать
        """
        return self.segmentation_head(x)

2.2 Реализация метрик (3.5 балла)¶

В задаче сегментации для оценки предсказательной способности нейронной сети, в основном, используют следующие метрики:

Пусть $\mathrm{P}$ обозначает прогноз сег. маски (Prediction), $\mathrm{S}$ обозначает score'ы для каждого класса сег. маски (Scores), а $\mathrm{T}$ означает сег. маску (Target). Тогда:

  • Intersection over Union metric (коэффициент Жаккара): $$ \mathrm{IoU}(P, T) = \dfrac{\sum_{i=1}^{M}\sum_{j=1}^{N}[P_{ij}*T_{ij}]}{\sum_{i=1}^{M}\sum_{j=1}^{N} [P_{ij} + T_{ij} - P_{ij}*T_{ij}]}\text{, где } P, T \in \{0, 1\}^{M \times N} $$
  • Recall metric (полнота): $$ \mathrm{Recall}(P, T) = \dfrac{\sum_{i=1}^{M}\sum_{j=1}^{N}[P_{ij} * T_{ij}]}{\sum_{i=1}^{M}\sum_{j=1}^{N} T_{ij}}\text{, где } P, T \in \{0, 1\}^{M \times N} $$ Указанные выше метрики расписаны для случая бинарной сегментации, которая нам не подходит. Обобщим их на случай мультиклассовой сегментации: представим K-классовую задачу как K двухклассовых, а затем макро- или микро-усредним для них метрики. Требуется реализовать мультиклассовые варианты указанных метрик с поддержкой макро- и микро-усреднения (по 1 баллу). Обратите внимание, что метрики рассчитываются для каждого элемента из батча. За редуцирование метрик вдоль размерности батча отвечает аргумент reduce (см. ниже).

Также для обучения будем использовать две разные, но схожие функции потерь:

  • Cross Entropy Loss (кросс-энтропия): $$ \mathrm{CE}(S, T) = - \dfrac{1}{MN}\sum_{c=1}^{K}\sum_{i=1}^{M}\sum_{j=1}^{N} \big[\log \mathrm{Softmax}(S)_{cij}*\mathbb{I}[T_{ij} == c]\big]\text{, где } S \in \mathbb{R}^{K \times M \times N}, T \in \{1, ..., K\}^{M \times N} $$
  • Focal Loss: $$ \mathrm{FL}(S, T) = - \dfrac{1}{MN}\sum_{c=1}^{K}\sum_{i=1}^{M}\sum_{j=1}^{N} \big[(1 - \mathrm{Softmax}(S)_{cij})^{\gamma}*\log \mathrm{Softmax}(S)_{cij}*\mathbb{I}[T_{ij} == c]\big]\text{, где } S \in \mathbb{R}^{K \times M \times N}, T \in \{1, ..., K\}^{M \times N}, \gamma \in \mathbb{R}_{+} - \text{гиперпараметр} $$

Требуется реализовать обе функции потерь. Также всюду необходимо обеспечить корректную обработку значений ignore_index, которые в нашем случае равны 255 (не участвуют в расчете метрик/функций потерь). Если представители некоторых классов в $\mathrm{T}$ отсутствуют, то учитывать эти классы при макро-усреднении не нужно.

In [16]:
class MetricsCollection():
    def __init__(self, num_classes: int, ignore_index: int = 255):
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        
    def IoUMetric(self, prediction: torch.tensor, target: torch.tensor, average: str = "macro", reduce: str = "mean") -> Union[torch.tensor, float]:
        """
        `prediction` предсказанная сегментационная маска размера (Batch_size, Height, Width)
        `target` истинная сегментационная маска размера (Batch_size, Height, Width)
        `average` тип мультклассового усреднения
        `reduce` редукция значений метрики вдоль размерности Batch; None - без редукции
        """
        assert average in ["micro", "macro"]
        assert reduce in ["sum", "mean", "none"]
        
        """
        ==== YOUR CODE =====
             ¯\_(ツ)_/¯
        """
        ignore_idx = torch.where(target == self.ignore_index)
        n_batch = prediction.shape[0]

        nums = torch.zeros(n_batch, self.num_classes)
        dens = torch.zeros(n_batch, self.num_classes)
        dens_macro = torch.zeros(n_batch, self.num_classes)

        classes_in_target = torch.ones(n_batch, 1) * self.num_classes
        
        for c in range(self.num_classes):
            pred_c = prediction.clone()
            pred_c[prediction == c] = 1
            pred_c[prediction != c] = 0
            pred_c[ignore_idx] = 0

            target_c = target.clone()
            target_c[target == c] = 1
            target_c[target != c] = 0
            target_c[ignore_idx] = 0
            
            nums[:, c] = (pred_c * target_c).sum(dim=(1, 2))
            dens[:, c] = (pred_c + target_c - pred_c * target_c).sum(dim=(1, 2))

            not_in_target = (target_c.sum(dim=(1,2)) == 0).view(-1)
            classes_in_target[not_in_target] -= 1
            dens_macro[:, c] = dens[:, c]
            dens_macro[not_in_target, c] = 1  # prevent zero division

        if average == 'micro':
            iou = nums.sum(dim=1) / dens.sum(dim=1)
        else:
            iou = (nums / dens_macro)
            iou = iou.sum(dim=1)
            iou = iou / classes_in_target

        if reduce == 'sum':
            return iou.sum()
        elif reduce == 'mean':
            return iou.mean()
        return iou
            
    def RecallMetric(self, prediction: torch.tensor, target: torch.tensor, average: str = "macro", reduce: str = "mean") -> Union[torch.tensor, float]:
        """
        `prediction` предсказанная сегментационная маска размера (Batch_size, Height, Width)
        `target` истинная сегментационная маска размера (Batch_size, Height, Width)
        `average` тип мультклассового усреднения
        `reduce` редукция значений метрики вдоль размерности Batch; None - без редукции
        """
        assert average in ["micro", "macro"]
        assert reduce in ["sum", "mean", "none"]
        
        """
        ==== YOUR CODE =====
             ¯\_(ツ)_/¯
        """
        ignore_idx = torch.where(target == self.ignore_index)
        n_batch = prediction.shape[0]

        nums = torch.zeros(n_batch, self.num_classes)
        dens = torch.zeros(n_batch, self.num_classes)
        dens_macro = torch.zeros(n_batch, self.num_classes)

        classes_in_target = torch.ones(n_batch, 1) * self.num_classes

        for c in range(self.num_classes):
            pred_c = prediction.clone()
            pred_c[prediction == c] = 1
            pred_c[prediction != c] = 0
            pred_c[ignore_idx] = 0
            
            target_c = target.clone()
            target_c[target == c] = 1
            target_c[target != c] = 0
            target_c[ignore_idx] = 0
            
            nums[:, c] = (pred_c * target_c).sum(dim=(1, 2))
            dens[:, c] = (target_c).sum(dim=(1, 2))

            not_in_target = (target_c.sum(dim=(1,2)) == 0).view(-1)
            classes_in_target[not_in_target] -= 1
            dens_macro[:, c] = dens[:, c]
            dens_macro[not_in_target, c] = 1  # prevent zero division

        if average == 'micro':
            iou = nums.sum(dim=1) / dens.sum(dim=1)
        else:
            iou = (nums / dens_macro)
            iou = iou.sum(dim=1)
            iou = iou / classes_in_target
        
        if reduce == 'sum':
            return iou.sum()
        elif reduce == 'mean':
            return iou.mean()
        return iou
    
    def FocalLoss(self, scores: torch.tensor, target: torch.tensor, reduce: str = "mean", gamma: float = 1.) -> Union[torch.tensor, float]:
        """
        `scores` score'ы каждого класса сегментационной маски размера (Batch_size, num_classes, Height, Width)
        `target` истинная сегментационная маска размера (Batch_size, Height, Width)
        `reduce` редукция значений функции потерь вдоль размерности Batch; None - без редукции
        """
        assert scores.shape[1] == self.num_classes
        assert reduce in ["sum", "mean", "none"]
        
        ce_loss = F.cross_entropy(scores, target, ignore_index=self.ignore_index, reduction="none")
        coef = (1 - torch.exp(-ce_loss))**gamma
        focal_loss = coef * ce_loss
        norm = (focal_loss.numel() - (target == self.ignore_index).sum())

        if (reduce == "sum"):
            return focal_loss.sum() / norm * scores.shape[0]
        elif (reduce == "mean"):
            return focal_loss.sum() / norm
        else:
            return focal_loss.sum(dim=[1, 2]) / norm * scores.shape[0]

    def CrossEntropyLoss(self, scores: torch.tensor, target: torch.tensor, reduce: str = "mean") -> Union[torch.tensor, float]:
        """
        `scores` score'ы каждого класса сегментационной маски размера (Batch_size, num_classes, Height, Width)
        `target` истинная сегментационная маска размера (Batch_size, Height, Width)
        `reduce` редукция значений функции потерь вдоль размерности Batch; None - без редукции
        """
        assert scores.shape[1] == self.num_classes
        assert reduce in ["sum", "mean", "none"]

        if (reduce == "sum"):
            return F.cross_entropy(scores, target, ignore_index=self.ignore_index, reduction="mean") * scores.shape[0]
        elif (reduce == "mean"):
            return F.cross_entropy(scores, target, ignore_index=self.ignore_index, reduction="mean")
        else:
            return F.cross_entropy(scores, target, ignore_index=self.ignore_index, reduction="none")

    @classmethod
    def ListMetrics(cls):
        return [method for method in dir(cls) if (method.endswith("Metric"))]
    
    @classmethod
    def ListLosses(cls):
        return [method for method in dir(cls) if (method.endswith("Loss"))]
In [17]:
metric_class = MetricsCollection(num_classes=3, ignore_index=255)

prediction = torch.tensor([[[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
                           [[0, 0, 0, 0], [0, 2, 2, 0], [0, 2, 0, 0], [0, 0, 0, 0]]])

target = torch.tensor([[[0, 0, 0, 0], [0, 1, 255, 0], [0, 1, 255, 0], [0, 0, 0, 0]],
                       [[0, 0, 0, 0], [0, 255, 2, 0], [0, 255, 2, 0], [0, 0, 0, 0]]])

assert np.isclose(metric_class.RecallMetric(prediction, target, "micro", "mean").item(), 0.9286, atol=1e-3)
assert np.isclose(metric_class.RecallMetric(prediction, target, "macro", "mean").item(), 0.7500, atol=1e-3)
assert np.isclose(metric_class.IoUMetric(prediction, target, "micro", "mean").item(), 0.8667, atol=1e-3)
assert np.isclose(metric_class.IoUMetric(prediction, target, "macro", "mean").item(), 0.7115, atol=1e-3)

Ответьте на вопрос (№1): Что говорит о предсказательной способности нашей сети ситуация: высокий Recall и низкий IoU для некоторого класса? Возможна ли обратная ситуация?

Ваш ответ: это говорит о том, что сеть предсказывает объект размера больше, чем он есть на самом деле. Обратная ситуация, низкий recall и высокий iou, быть не может, так как если маленький recall, то будет и малеьникй P * T (числитель IoU).

Ответьте на вопрос (№2): Какой вид усреднения правильней использовать в нашей задаче: макро и микро? Почему?

Ваш ответ: макро, так как пиксели-классы внутри одного изображения всегда будут несбалансированны. Макро лучше подходит для дисбаланса, ошибки на всех классах влияют на макро-метрику равнозначно. А микро же благодаря усреднению по всем семплам всех классов будет нивелировать возможные ошибки в маленьких классах.

Ответьте на вопрос (№3): В чем преимущество Focal Loss перед Cross Entropy Loss? Что контроллирует гиперпараметр 𝛾 в Focal Loss?

Ваш ответ: Focal Loss помогает при дисбалансе классов, это благодаря тому, что он хорошо штрафует за ошибки, при этом одинаково поощряет за уверенные ответы (скажем за если модель выдала p>=0.8, то ответ уверенный и при ground truth=1 поощряться такие ответы будут одинаково). Гиперпараметр \gamma отвечает как раз таки с какого p мы считаем, что ответ уверенный

Часть 3: Обучение PSPNet, эксперименты¶

Теперь осталось лишь собрать все написанное ранее воедино и обучить нашу сеть. Чтобы контроллировать процесс обучения нашей сети, будем вычислять усредненные метрики и функции потерь на валидационной выборке. Для удобства отображения информации воспользуемся инструментом tensorboard. Для этого заведем объект класса SummaryWriter, который создаст и откроет на запись специальный event файл для tensorboard. Для визуализации содержимого вводится команда tensorboard --logdir=<PATH> в терминале. Если возникла необходимость в мониториге нескольких tensorboard, то каждому из них требуется присвоить свой уникальный порт --port <PORT>. Пример использования tensorboard на Google Colab.

Требуется написать методы train_model и test_model. Вся конфигурация обучения хранится в словаре train_config. При желании его можно дополнить чем-то своим.

К вашему решению потребуется прикрепить логи tensorboard. Чтобы облегчить процедуру проверки настоятельно рекомендуется пользоваться inline-tensorboard:

%load_ext tensorboard
%tensorboard --logdir ./runs

3.1 Реализация процедур обучения/тестирования сети (1 балл)¶

In [23]:
class PSPNet(nn.Module):
    def __init__(self, pretrained_model: Type[ResNet], HeadBlock: Type[nn.Module], num_classes: int, train_config: dict, bin_sizes: tuple[int, ...] = (1, 2, 3, 6)):
        """
        `pretrained_model` модель предобученного кодировщика
        `Head` класс блока, оценивающего score'ы для каждого класса сегментационной маски
        `num_class` число классов сегментации
        `train_config` словарь с конфигурацией процесса обучения сети
        `bin_sizes` пространственные размеры к которым сводит пулинг в блоке PPM
        """
        super().__init__()
        self.encoder = EncoderBlock(pretrained_model)
        self.encoder.freeze()
        mid_channels = 256 # TODO: 256 / 128 / 64
        self.decoder = DecoderBlock(512, mid_channels, bin_sizes) 
        self.head = HeadBlock(mid_channels, num_classes)
        
        self.train_config = train_config
        self.metric_class = train_config["metric_class"]
        self.optimizer = train_config["optimizer"](self.parameters(), **train_config["optimizer_params"])
        self.scheduler = train_config["scheduler"](self.optimizer, **train_config["scheduler_params"])
        
    def forward(self, x: torch.tensor) -> tuple[torch.tensor, torch.tensor]:
        # Для гарантии отсутствия градиентов по кодировщику
        with torch.no_grad():
            x_main, x_supp = self.encoder(x)
        out = self.decoder(x_main, x_supp)
        out = self.head(out, x_supp)
        return out, torch.argmax(out.detach(), dim=1)
    
    def write_val_metrics(self, val_metrics: dict, iter_num: int, norm: float = 1.0) -> None:
        """
        Записывает усредненные значения метрик/функций потерь в tensorboard
        
        `val_metrics` словарь с ключами "название_метрики/функции потерь" и их значениями
        `iter_num` номер глобальной итерации (по формуле #всего_итераций * номер_эпохи + номер_итерации)
        `norm` фактор нормализации; для усреднения равен числу объектов в валидационной выборке
        """
        for method, value in val_metrics.items():
            self.train_config["writer"].add_scalar(f"Mean {method}", np.round(val_metrics[method].item()/norm, 2), iter_num)
    
    def validate_model(self, val_dataloader: Type[DataLoader], iter_num: int) -> None:
        """
        Валидирует текущую модель и вычисляет соответствующие метрики/функции потерь
        
        `val_dataloader` валидационная выборка
        `iter_num` номер глобальной итерации (по формуле #всего_итераций * номер_эпохи + номер_итерации)
        """
        # Выставляет декодировщик в режим валидации (влияет на поведение BatchNorm2d и Dropout)
        self.decoder.eval()
        
        # Инициализация словаря метрик/функций потерь 
        val_metrics = dict([(method, 0.0) for method in (self.metric_class.ListMetrics() + self.metric_class.ListLosses())])
        
        # Обязательно считать с контекстным менеджером torch.no_grad()
        # Даже если мы не делаем шаг оптимизации, мы экономим память (не считаем градиенты)
        with torch.no_grad():
            for input, target in val_dataloader:
                scores, prediction = self.forward(input)
                for metric in self.metric_class.ListMetrics():
                    val_metrics[metric] += getattr(self.metric_class, metric)(prediction, target, reduce="sum")
                    
                for loss in self.metric_class.ListLosses():
                    val_metrics[loss] += getattr(self.metric_class, loss)(scores, target, reduce="sum")
        
        # Tensorboard также позволяет сохранять визуализацию наших предсказаний в ходе обучения
        figure = draw((input[0], target[0]), t_dict, prediction[0], log=True)
        self.train_config["writer"].add_figure("image/GT/prediction", figure, iter_num)
        
        self.write_val_metrics(val_metrics, iter_num, norm=len(val_dataloader.dataset))
        # Возвращает режим обучения декодировщика
        self.decoder.train()
        
    def train_model(self, train_dataloader: Type[DataLoader], val_dataloader: Type[DataLoader]) -> None:
        """
        Обучает модель на обучающей выборке, периодически (периодичность выставляется в train_config) валидирует на валидационной выборке
        В конце каждой эпохи сохраняет модель на диск
        
        `train_dataloader` обучающая выборка
        `val_dataloader` валидационная выборка
        """
        # Выставляет режим обучения декодировщика
        self.decoder.train()
        
        for epoch in range(self.train_config["num_epochs"]):
            for iter_num, (input, target) in enumerate(train_dataloader):
                self.optimizer.zero_grad()
                """
                ==== YOUR CODE =====
                     ¯\_(ツ)_/¯
                """
                scores, pred = self.forward(input)
                loss = self.train_config['loss_fn'](scores, target)
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()

                if (iter_num % self.train_config["validate_each_iter"] == 0):
                    print(f"Epoch: {epoch+1}/{self.train_config['num_epochs']} || Iter: {iter_num}/{len(train_dataloader)} || Loss: {loss.item()}")
                    self.validate_model(val_dataloader, epoch * len(train_dataloader) + iter_num)
                    
            torch.save(self.state_dict(), self.train_config["save_model_path"] + f"_{epoch+1}.pth")
                        
    def test_model(self, test_dataloader: Type[DataLoader]) -> tuple[torch.tensor, torch.tensor]:
        """
        Inference модели на тестовой выборке. Возвращает тензор предсказаний сег.масок и тензор истинных сег.масок
        
        `test_dataloader` тестовая выборка
        """
        # Выставляет декодировщик в режим валидации (влияет на поведение BatchNorm2d и Dropout)
        self.decoder.eval()
        """
        ==== YOUR CODE =====
             ¯\_(ツ)_/¯
        """
        dl_prediction = []
        dl_target = []
        with torch.no_grad():
            for iter_num, (input, target) in enumerate(test_dataloader):
                scores, pred = self.forward(input)
                dl_prediction.append(pred.detach().cpu())
                dl_target.append(target.detach().cpu())
        
        dl_prediction = torch.cat(dl_prediction, dim=0)
        dl_target = torch.cat(dl_target, dim=0)

        return dl_prediction, dl_target

3.2 Обучение PSPNet, эксперименты (6 баллов)¶

Вам приведены начальные значения гиперпараметров сети. Подберите гиперпараметры (если необходимо) и обучите сеть на обе функции потерь CrossEntropyLoss и FocalLoss. Добейтесь следующих результатов на тестовой выборке хотя бы для одной из них:

  • Mean IoU metric > 0.87
  • Mean Recall metric > 0.96

К вашему решению требуется прикрепить логи tensorboard.

In [34]:
from torch.optim.lr_scheduler import StepLR

train_config = {
    "num_epochs": 5,
    "optimizer": torch.optim.Adam,
    "optimizer_params": {
        "lr": 1e-3,
        "weight_decay": 1e-5
    },
    "loss_fn": metric_class.FocalLoss, # or metric_class.FocalLoss
    "scheduler": StepLR,
    "scheduler_params": {
        "step_size": 50,
        "gamma": 0.85
    },
    "validate_each_iter": 10,
    "writer": SummaryWriter(comment="Floss"), #Floss
    "save_model_path": 'FLoss.pth',
    "metric_class": metric_class
}

net = PSPNet(pretrained_model, SegmentationHead, num_classes=3, train_config=train_config).to(DEVICE)
print("#параметров в сети:", count_parameters(net))
Pretrained main module conv1 is loaded
Pretrained main module bn1 is loaded
Pretrained main module relu is loaded
Pretrained main module maxpool is loaded
Pretrained main module layer1 is loaded
Pretrained main module layer2 is loaded
Pretrained supp module layer3 is loaded
Pretrained supp module layer4 is loaded
#параметров в сети: 11759802
In [35]:
net.train_model(train_dataloader_memmap, val_dataloader_memmap)
Epoch: 1/5 || Iter: 0/184 || Loss: 0.809173047542572
Epoch: 1/5 || Iter: 10/184 || Loss: 0.4578273892402649
Epoch: 1/5 || Iter: 20/184 || Loss: 0.20597709715366364
Epoch: 1/5 || Iter: 30/184 || Loss: 0.1933986395597458
Epoch: 1/5 || Iter: 40/184 || Loss: 0.19391778111457825
Epoch: 1/5 || Iter: 50/184 || Loss: 0.13917164504528046
Epoch: 1/5 || Iter: 60/184 || Loss: 0.09061640501022339
Epoch: 1/5 || Iter: 70/184 || Loss: 0.10481450706720352
Epoch: 1/5 || Iter: 80/184 || Loss: 0.14566083252429962
Epoch: 1/5 || Iter: 90/184 || Loss: 0.11853332817554474
Epoch: 1/5 || Iter: 100/184 || Loss: 0.09860879927873611
Epoch: 1/5 || Iter: 110/184 || Loss: 0.12638220191001892
Epoch: 1/5 || Iter: 120/184 || Loss: 0.10887762904167175
Epoch: 1/5 || Iter: 130/184 || Loss: 0.08282797038555145
Epoch: 1/5 || Iter: 140/184 || Loss: 0.09824994951486588
Epoch: 1/5 || Iter: 150/184 || Loss: 0.09718429297208786
Epoch: 1/5 || Iter: 160/184 || Loss: 0.16996942460536957
Epoch: 1/5 || Iter: 170/184 || Loss: 0.12718182802200317
Epoch: 1/5 || Iter: 180/184 || Loss: 0.10510507225990295
Epoch: 2/5 || Iter: 0/184 || Loss: 0.08859036862850189
Epoch: 2/5 || Iter: 10/184 || Loss: 0.13131177425384521
Epoch: 2/5 || Iter: 20/184 || Loss: 0.11125364899635315
Epoch: 2/5 || Iter: 30/184 || Loss: 0.0709814578294754
Epoch: 2/5 || Iter: 40/184 || Loss: 0.17172373831272125
Epoch: 2/5 || Iter: 50/184 || Loss: 0.08666875213384628
Epoch: 2/5 || Iter: 60/184 || Loss: 0.11063261330127716
Epoch: 2/5 || Iter: 70/184 || Loss: 0.10458337515592575
Epoch: 2/5 || Iter: 80/184 || Loss: 0.1034105122089386
Epoch: 2/5 || Iter: 90/184 || Loss: 0.1510966271162033
Epoch: 2/5 || Iter: 100/184 || Loss: 0.09208666533231735
Epoch: 2/5 || Iter: 110/184 || Loss: 0.4372779130935669
Epoch: 2/5 || Iter: 120/184 || Loss: 0.13854968547821045
Epoch: 2/5 || Iter: 130/184 || Loss: 0.08269744366407394
Epoch: 2/5 || Iter: 140/184 || Loss: 0.0856635645031929
Epoch: 2/5 || Iter: 150/184 || Loss: 0.08284060657024384
Epoch: 2/5 || Iter: 160/184 || Loss: 0.08710094541311264
Epoch: 2/5 || Iter: 170/184 || Loss: 0.12266074120998383
Epoch: 2/5 || Iter: 180/184 || Loss: 0.11962353438138962
Epoch: 3/5 || Iter: 0/184 || Loss: 0.0773673728108406
Epoch: 3/5 || Iter: 10/184 || Loss: 0.07409332692623138
Epoch: 3/5 || Iter: 20/184 || Loss: 0.07445875555276871
Epoch: 3/5 || Iter: 30/184 || Loss: 0.08721039444208145
Epoch: 3/5 || Iter: 40/184 || Loss: 0.08603931963443756
Epoch: 3/5 || Iter: 50/184 || Loss: 0.1051332876086235
Epoch: 3/5 || Iter: 60/184 || Loss: 0.07602953910827637
Epoch: 3/5 || Iter: 70/184 || Loss: 0.0799580067396164
Epoch: 3/5 || Iter: 80/184 || Loss: 0.2258630096912384
Epoch: 3/5 || Iter: 90/184 || Loss: 0.06571513414382935
Epoch: 3/5 || Iter: 100/184 || Loss: 0.0787486881017685
Epoch: 3/5 || Iter: 110/184 || Loss: 0.08627845346927643
Epoch: 3/5 || Iter: 120/184 || Loss: 0.044554226100444794
Epoch: 3/5 || Iter: 130/184 || Loss: 0.07834852486848831
Epoch: 3/5 || Iter: 140/184 || Loss: 0.11499042063951492
Epoch: 3/5 || Iter: 150/184 || Loss: 0.06482337415218353
Epoch: 3/5 || Iter: 160/184 || Loss: 0.09408998489379883
Epoch: 3/5 || Iter: 170/184 || Loss: 0.09243714064359665
Epoch: 3/5 || Iter: 180/184 || Loss: 0.0742005854845047
Epoch: 4/5 || Iter: 0/184 || Loss: 0.0555943064391613
Epoch: 4/5 || Iter: 10/184 || Loss: 0.07216621190309525
Epoch: 4/5 || Iter: 20/184 || Loss: 0.07730989903211594
Epoch: 4/5 || Iter: 30/184 || Loss: 0.14474773406982422
Epoch: 4/5 || Iter: 40/184 || Loss: 0.06566669791936874
Epoch: 4/5 || Iter: 50/184 || Loss: 0.07082650810480118
Epoch: 4/5 || Iter: 60/184 || Loss: 0.07712826132774353
Epoch: 4/5 || Iter: 70/184 || Loss: 0.08188728988170624
Epoch: 4/5 || Iter: 80/184 || Loss: 0.05112256482243538
Epoch: 4/5 || Iter: 90/184 || Loss: 0.07622982561588287
Epoch: 4/5 || Iter: 100/184 || Loss: 0.05635732039809227
Epoch: 4/5 || Iter: 110/184 || Loss: 0.08246149122714996
Epoch: 4/5 || Iter: 120/184 || Loss: 0.06526417285203934
Epoch: 4/5 || Iter: 130/184 || Loss: 0.07732726633548737
Epoch: 4/5 || Iter: 140/184 || Loss: 0.05847015231847763
Epoch: 4/5 || Iter: 150/184 || Loss: 0.1348450481891632
Epoch: 4/5 || Iter: 160/184 || Loss: 0.20238502323627472
Epoch: 4/5 || Iter: 170/184 || Loss: 0.06471627205610275
Epoch: 4/5 || Iter: 180/184 || Loss: 0.06523995101451874
Epoch: 5/5 || Iter: 0/184 || Loss: 0.07207831740379333
Epoch: 5/5 || Iter: 10/184 || Loss: 0.0672682523727417
Epoch: 5/5 || Iter: 20/184 || Loss: 0.08528759330511093
Epoch: 5/5 || Iter: 30/184 || Loss: 0.11275054514408112
Epoch: 5/5 || Iter: 40/184 || Loss: 0.07174934446811676
Epoch: 5/5 || Iter: 50/184 || Loss: 0.04678089916706085
Epoch: 5/5 || Iter: 60/184 || Loss: 0.07671410590410233
Epoch: 5/5 || Iter: 70/184 || Loss: 0.09693682193756104
Epoch: 5/5 || Iter: 80/184 || Loss: 0.061378393322229385
Epoch: 5/5 || Iter: 90/184 || Loss: 0.06906907260417938
Epoch: 5/5 || Iter: 100/184 || Loss: 0.06854525208473206
Epoch: 5/5 || Iter: 110/184 || Loss: 0.050054144114255905
Epoch: 5/5 || Iter: 120/184 || Loss: 0.07530767470598221
Epoch: 5/5 || Iter: 130/184 || Loss: 0.04739246517419815
Epoch: 5/5 || Iter: 140/184 || Loss: 0.0572473481297493
Epoch: 5/5 || Iter: 150/184 || Loss: 0.10901613533496857
Epoch: 5/5 || Iter: 160/184 || Loss: 0.07046938687562943
Epoch: 5/5 || Iter: 170/184 || Loss: 0.07678557932376862
Epoch: 5/5 || Iter: 180/184 || Loss: 0.08337099850177765

Протестируйте обе модели, сравните метрики:

In [21]:
net.load_state_dict(torch.load('./CELoss.pth_5.pth'))
net.eval()
Out[21]:
PSPNet(
  (encoder): EncoderBlock(
    (encoder_main): Sequential(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer2): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (encoder_supp): Sequential(
      (layer3): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (4): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (5): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
  )
  (decoder): DecoderBlock(
    (PPM): PyramidPoolingModule(
      (bins): ModuleList(
        (0): Sequential(
          (0): AdaptiveAvgPool2d(output_size=1)
          (1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
          (2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
        )
        (1): Sequential(
          (0): AdaptiveAvgPool2d(output_size=2)
          (1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
          (2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
        )
        (2): Sequential(
          (0): AdaptiveAvgPool2d(output_size=3)
          (1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
          (2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
        )
        (3): Sequential(
          (0): AdaptiveAvgPool2d(output_size=6)
          (1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
          (2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
        )
      )
    )
    (SM): SupplementaryModule(
      (suppl): Sequential(
        (0): Conv2d(2048, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Dropout2d(p=0.1, inplace=False)
        (4): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
        (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
      )
    )
    (UM): UpsampleModule(
      (upsample): Sequential(
        (0): Upsample(
          (us_transform): Sequential(
            (0): Conv2d(548, 274, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(274, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
        (1): Upsample(
          (us_transform): Sequential(
            (0): Conv2d(274, 137, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(137, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
        (2): Upsample(
          (us_transform): Sequential(
            (0): Conv2d(137, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
      )
    )
  )
  (head): SegmentationHead(
    (segmentation_head): Sequential(
      (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Dropout2d(p=0.0, inplace=False)
      (4): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)
In [26]:
dl_prediction, dl_target = net.test_model(test_dataloader)
In [ ]:
dl_prediction, dl_target = net.test_model(test_dataloader)
In [27]:
print("Mean IoU metric: ", metric_class.IoUMetric(dl_prediction, dl_target))
print("Mean Recall metric: ", metric_class.RecallMetric(dl_prediction, dl_target))
Mean IoU metric:  tensor(0.9417)
Mean Recall metric:  tensor(0.9694)

Примеры работы вами обученной сети:

In [31]:
img_idx = np.random.randint(0, 100)
for idx, (input, target) in enumerate(test_dataloader):
    if (idx < img_idx):
        continue
    draw((input[0].squeeze(), target[0].squeeze()), t_dict, dl_prediction[8*idx])
    plt.pause(0.1)
    if (idx == img_idx+2):
        break

Ответьте на вопрос: Как выбор функции потерь влияет на рассчитываемые метрики в ходе обучения?

Ваш ответ: с CrossEntropyLoss метрики в ходе обучения (IOU, Recall) лучше, чем с FocalLoss. Тут нет какого-то строгого обоснования почему, ведь вроде FocalLoss теоретически должен помогать при дисбалансе (а у нас есть дисбаланс пикселей внутри изображения). По всей видимости с нашими данными и нашей моделью этот дисбаланс не настолько существенен, чтобы FocalLoss получало преимущество.

3.3 Бонусное задание: Реализация и обучение двуглавой сети (3 балла)¶

До этого момента мы ни разу не использовали тот факт, что в нашем датасете не бывает слуаев, в которых и собака, и кошка одновременно находятся в кадре. В это же время блок SegmentationHead допускает этот случай, что дает теоретическую возможность модели ошибиться. Чтобы повысить устойчивость модели мы будем использовать две головы: голова двухклассовой сегментации, которая сегментирует животное на изображении, а вторая голова бинарной классификации будет предсказывать, что это за животное (собака или кошка). Таким образом, наша модель не имеет возможности отнести голову животного к классу "собака", а туловище к классу "кошка", что увеличивает ее устойчивость. Реализуйте двуглавый блок SegmentationClassificationHeads.

In [48]:
class SegmentationClassificationHeads(nn.Module):
    def __init__(self, in_channels: int, num_classes: int, dropout: float = 0.1):
        """
        Вычисляет score для каждого из классов
        Вход: тензор (Batch_size, `in_channels`, Height, Width)
        
        Выход: тензор (Batch_size, `num_classes`, Height, Width)
        """
        super().__init__()
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(in_channels // 2, num_classes, kernel_size=1),
        )
        
        self.classification_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32, num_classes - 1),
            nn.Softmax(dim=1)
        )
        
    def combine_heads(self, seg_pred: torch.tensor, cls_pred: torch.tensor) -> torch.tensor:
        """
        ==== YOUR CODE =====
             ¯\_(ツ)_/¯
        """
        labels_pred = cls_pred.argmax(axis=1)
        mask = torch.zeros_like(seg_pred)
        mask[:, labels_pred, :, :] = 1

        return seg_pred * mask

    def forward(self, x: torch.tensor, x_supp: torch.tensor) -> torch.tensor:
        """
        Вот мы и воспользовались ранее фиктивным аргументом `x_supp`
        """
        cls_pred = self.classification_head(x_supp)
        seg_pred = self.segmentation_head(x)
        return self.combine_heads(seg_pred, cls_pred)

Обучите двуглавую сеть и получите улучшение метрик относительно наилучшего результата предыдущего пункта:

  • Mean IoU metric > 0.93
  • Mean Recall metric > 0.96

К вашему решению требуется прикрепить логи tensorboard.

In [55]:
train_config["writer"] = SummaryWriter(comment="TwoHead_CEloss") #TwoHead_Floss
train_config["save_model_path"] = 'SegClfHead.pth'
train_config['loss_fn'] = metric_class.CrossEntropyLoss

net = PSPNet(pretrained_model, SegmentationClassificationHeads, num_classes=3, train_config=train_config).to(DEVICE)
print("#параметров в сети:", count_parameters(net))
Pretrained main module conv1 is loaded
Pretrained main module bn1 is loaded
Pretrained main module relu is loaded
Pretrained main module maxpool is loaded
Pretrained main module layer1 is loaded
Pretrained main module layer2 is loaded
Pretrained supp module layer3 is loaded
Pretrained supp module layer4 is loaded
#параметров в сети: 11759868
In [56]:
net.train_model(train_dataloader_memmap, val_dataloader_memmap)
Epoch: 1/5 || Iter: 0/184 || Loss: 1.006283164024353
Epoch: 1/5 || Iter: 10/184 || Loss: 0.3502597212791443
Epoch: 1/5 || Iter: 20/184 || Loss: 0.3152938485145569
Epoch: 1/5 || Iter: 30/184 || Loss: 0.229232057929039
Epoch: 1/5 || Iter: 40/184 || Loss: 0.32625967264175415
Epoch: 1/5 || Iter: 50/184 || Loss: 0.22314202785491943
Epoch: 1/5 || Iter: 60/184 || Loss: 0.29661625623703003
Epoch: 1/5 || Iter: 70/184 || Loss: 0.17813940346240997
Epoch: 1/5 || Iter: 80/184 || Loss: 0.22324040532112122
Epoch: 1/5 || Iter: 90/184 || Loss: 0.1593639850616455
Epoch: 1/5 || Iter: 100/184 || Loss: 0.2098541557788849
Epoch: 1/5 || Iter: 110/184 || Loss: 0.17502334713935852
Epoch: 1/5 || Iter: 120/184 || Loss: 0.19200189411640167
Epoch: 1/5 || Iter: 130/184 || Loss: 0.1552913933992386
Epoch: 1/5 || Iter: 140/184 || Loss: 0.21559476852416992
Epoch: 1/5 || Iter: 150/184 || Loss: 0.3583393692970276
Epoch: 1/5 || Iter: 160/184 || Loss: 0.17660365998744965
Epoch: 1/5 || Iter: 170/184 || Loss: 0.190147265791893
Epoch: 1/5 || Iter: 180/184 || Loss: 0.1797989159822464
Epoch: 2/5 || Iter: 0/184 || Loss: 0.17468833923339844
Epoch: 2/5 || Iter: 10/184 || Loss: 0.13037434220314026
Epoch: 2/5 || Iter: 20/184 || Loss: 0.194634348154068
Epoch: 2/5 || Iter: 30/184 || Loss: 0.16901983320713043
Epoch: 2/5 || Iter: 40/184 || Loss: 0.23156394064426422
Epoch: 2/5 || Iter: 50/184 || Loss: 0.17543001472949982
Epoch: 2/5 || Iter: 60/184 || Loss: 0.21650846302509308
Epoch: 2/5 || Iter: 70/184 || Loss: 0.1911405324935913
Epoch: 2/5 || Iter: 80/184 || Loss: 0.25600340962409973
Epoch: 2/5 || Iter: 90/184 || Loss: 0.1566762626171112
Epoch: 2/5 || Iter: 100/184 || Loss: 0.15329773724079132
Epoch: 2/5 || Iter: 110/184 || Loss: 0.15566708147525787
Epoch: 2/5 || Iter: 120/184 || Loss: 0.10658662021160126
Epoch: 2/5 || Iter: 130/184 || Loss: 0.14871226251125336
Epoch: 2/5 || Iter: 140/184 || Loss: 0.3607293963432312
Epoch: 2/5 || Iter: 150/184 || Loss: 0.1413830667734146
Epoch: 2/5 || Iter: 160/184 || Loss: 0.30521160364151
Epoch: 2/5 || Iter: 170/184 || Loss: 0.2131950557231903
Epoch: 2/5 || Iter: 180/184 || Loss: 0.1426122933626175
Epoch: 3/5 || Iter: 0/184 || Loss: 0.15959089994430542
Epoch: 3/5 || Iter: 10/184 || Loss: 0.17366008460521698
Epoch: 3/5 || Iter: 20/184 || Loss: 0.21151840686798096
Epoch: 3/5 || Iter: 30/184 || Loss: 0.15909597277641296
Epoch: 3/5 || Iter: 40/184 || Loss: 0.2074342519044876
Epoch: 3/5 || Iter: 50/184 || Loss: 0.13781821727752686
Epoch: 3/5 || Iter: 60/184 || Loss: 0.14632698893547058
Epoch: 3/5 || Iter: 70/184 || Loss: 0.14367517828941345
Epoch: 3/5 || Iter: 80/184 || Loss: 0.12310225516557693
Epoch: 3/5 || Iter: 90/184 || Loss: 0.15726590156555176
Epoch: 3/5 || Iter: 100/184 || Loss: 0.19093354046344757
Epoch: 3/5 || Iter: 110/184 || Loss: 0.13513793051242828
Epoch: 3/5 || Iter: 120/184 || Loss: 0.32211387157440186
Epoch: 3/5 || Iter: 130/184 || Loss: 0.11975201219320297
Epoch: 3/5 || Iter: 140/184 || Loss: 0.17612037062644958
Epoch: 3/5 || Iter: 150/184 || Loss: 0.11519519239664078
Epoch: 3/5 || Iter: 160/184 || Loss: 0.1660498082637787
Epoch: 3/5 || Iter: 170/184 || Loss: 0.15760758519172668
Epoch: 3/5 || Iter: 180/184 || Loss: 0.16032609343528748
Epoch: 4/5 || Iter: 0/184 || Loss: 0.11131583899259567
Epoch: 4/5 || Iter: 10/184 || Loss: 0.3819928765296936
Epoch: 4/5 || Iter: 20/184 || Loss: 0.0925675481557846
Epoch: 4/5 || Iter: 30/184 || Loss: 0.1356377899646759
Epoch: 4/5 || Iter: 40/184 || Loss: 0.18674714863300323
Epoch: 4/5 || Iter: 50/184 || Loss: 0.10063348710536957
Epoch: 4/5 || Iter: 60/184 || Loss: 0.13070477545261383
Epoch: 4/5 || Iter: 70/184 || Loss: 0.18780240416526794
Epoch: 4/5 || Iter: 80/184 || Loss: 0.1717657893896103
Epoch: 4/5 || Iter: 90/184 || Loss: 0.16520006954669952
Epoch: 4/5 || Iter: 100/184 || Loss: 0.14093822240829468
Epoch: 4/5 || Iter: 110/184 || Loss: 0.10212830454111099
Epoch: 4/5 || Iter: 120/184 || Loss: 0.17921021580696106
Epoch: 4/5 || Iter: 130/184 || Loss: 0.12966689467430115
Epoch: 4/5 || Iter: 140/184 || Loss: 0.10844019055366516
Epoch: 4/5 || Iter: 150/184 || Loss: 0.1326400190591812
Epoch: 4/5 || Iter: 160/184 || Loss: 0.17182667553424835
Epoch: 4/5 || Iter: 170/184 || Loss: 0.2265619933605194
Epoch: 4/5 || Iter: 180/184 || Loss: 0.11308327317237854
Epoch: 5/5 || Iter: 0/184 || Loss: 0.16694375872612
Epoch: 5/5 || Iter: 10/184 || Loss: 0.1647510677576065
Epoch: 5/5 || Iter: 20/184 || Loss: 0.13737812638282776
Epoch: 5/5 || Iter: 30/184 || Loss: 0.1677914559841156
Epoch: 5/5 || Iter: 40/184 || Loss: 0.11636580526828766
Epoch: 5/5 || Iter: 50/184 || Loss: 0.11996309459209442
Epoch: 5/5 || Iter: 60/184 || Loss: 0.11455247551202774
Epoch: 5/5 || Iter: 70/184 || Loss: 0.1332285851240158
Epoch: 5/5 || Iter: 80/184 || Loss: 0.1587146818637848
Epoch: 5/5 || Iter: 90/184 || Loss: 0.11577418446540833
Epoch: 5/5 || Iter: 100/184 || Loss: 0.19084015488624573
Epoch: 5/5 || Iter: 110/184 || Loss: 0.1464766561985016
Epoch: 5/5 || Iter: 120/184 || Loss: 0.11565650254487991
Epoch: 5/5 || Iter: 130/184 || Loss: 0.11677568405866623
Epoch: 5/5 || Iter: 140/184 || Loss: 0.13255789875984192
Epoch: 5/5 || Iter: 150/184 || Loss: 0.10625714808702469
Epoch: 5/5 || Iter: 160/184 || Loss: 0.13042517006397247
Epoch: 5/5 || Iter: 170/184 || Loss: 0.13463018834590912
Epoch: 5/5 || Iter: 180/184 || Loss: 0.1401868462562561

Тестируем модель:

In [61]:
net.load_state_dict(torch.load('./SegClfHead.pth_5.pth'))
net.eval()
Out[61]:
PSPNet(
  (encoder): EncoderBlock(
    (encoder_main): Sequential(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer2): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (encoder_supp): Sequential(
      (layer3): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (4): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (5): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
  )
  (decoder): DecoderBlock(
    (PPM): PyramidPoolingModule(
      (bins): ModuleList(
        (0): Sequential(
          (0): AdaptiveAvgPool2d(output_size=1)
          (1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
          (2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
        )
        (1): Sequential(
          (0): AdaptiveAvgPool2d(output_size=2)
          (1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
          (2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
        )
        (2): Sequential(
          (0): AdaptiveAvgPool2d(output_size=3)
          (1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
          (2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
        )
        (3): Sequential(
          (0): AdaptiveAvgPool2d(output_size=6)
          (1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
          (2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
        )
      )
    )
    (SM): SupplementaryModule(
      (suppl): Sequential(
        (0): Conv2d(2048, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Dropout2d(p=0.1, inplace=False)
        (4): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
        (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
      )
    )
    (UM): UpsampleModule(
      (upsample): Sequential(
        (0): Upsample(
          (us_transform): Sequential(
            (0): Conv2d(548, 274, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(274, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
        (1): Upsample(
          (us_transform): Sequential(
            (0): Conv2d(274, 137, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(137, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
        (2): Upsample(
          (us_transform): Sequential(
            (0): Conv2d(137, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
      )
    )
  )
  (head): SegmentationClassificationHeads(
    (segmentation_head): Sequential(
      (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Dropout2d(p=0.1, inplace=False)
      (4): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))
    )
    (classification_head): Sequential(
      (0): Flatten(start_dim=1, end_dim=-1)
      (1): Linear(in_features=32, out_features=2, bias=True)
      (2): Softmax(dim=1)
    )
  )
)
In [62]:
dl_prediction, dl_target = net.test_model(test_dataloader)
In [64]:
print("Mean IoU metric: ", metric_class.IoUMetric(dl_prediction, dl_target))
print("Mean Recall metric: ", metric_class.RecallMetric(dl_prediction, dl_target))
Mean IoU metric:  tensor(0.9452)
Mean Recall metric:  tensor(0.9723)

Примеры работы вами обученной двуглавой сети:

In [65]:
img_idx = np.random.randint(0, 100)
for idx, (input, target) in enumerate(test_dataloader):
    if (idx < img_idx):
        continue
    draw((input[0].squeeze(), target[0].squeeze()), t_dict, dl_prediction[8*idx])
    plt.pause(0.1)
    if (idx == img_idx+2):
        break
In [ ]: